{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "tags": [
     "parameters"
    ]
   },
   "outputs": [],
   "source": [
    "epochs = 10\n",
    "# We don't use the whole dataset for efficiency purposes, but feel free to increase these numbers\n",
    "n_train_items = 640\n",
    "n_test_items = 640"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Part X - Secure Training and Evaluation on MNIST\n",
    "\n",
    "When building Machine Learning as a Service (MLaaS) solutions, a company might need to request access to data from other partners to train its model. In health or finance, both the model and the data are extremely critical: the model parameters are a business asset, while the data is personal data that is tightly regulated.\n",
    "\n",
    "In this context, one possible solution is to encrypt both the model and the data and to train the machine learning model over the encrypted values. This guarantees, for example, that the company won't access patients' medical records and that health facilities won't be able to observe the model to which they contribute. Several encryption schemes allow computation over encrypted data, among which are Secure Multi-Party Computation (SMPC), Homomorphic Encryption (FHE/SHE) and Functional Encryption (FE). We will focus here on Multi-Party Computation (introduced in Tutorial 5), which consists of private additive sharing and relies on the crypto protocols SecureNN and SPDZ.\n",
    "\n",
    "The exact setting of this tutorial is the following: consider that you are the server and you would like to train your model on some data held by $n$ workers. The server secret shares its model and sends each share to a worker. The workers also secret share their data and exchange the shares among themselves. In the configuration that we will study, there are 2 workers: alice and bob. After exchanging shares, each of them holds one share of their own data, one share of the other worker's data, and one share of the model. Computation can now start to privately train the model using the appropriate crypto protocols. Once the model is trained, all the shares can be sent back to the server, which decrypts it. This is illustrated in the following figure:"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "![SMPC Illustration](https://github.com/OpenMined/PySyft/raw/11c85a121a1a136e354945686622ab3731246084/examples/tutorials/material/smpc_illustration.png)"
   ]
  },
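  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The sharing step illustrated above can be sketched in plain Python. This is a minimal sketch of additive secret sharing, not PySyft's implementation: the ring size `Q` and the helper names `share` and `reconstruct` are assumptions made for illustration only.\n",
    "\n",
    "```python\n",
    "import random\n",
    "\n",
    "Q = 2**62  # size of the ring the shares live in (assumed value for illustration)\n",
    "\n",
    "def share(secret, n_parties=2):\n",
    "    # Draw n - 1 random shares; the last share makes the total equal the secret mod Q\n",
    "    shares = [random.randrange(Q) for _ in range(n_parties - 1)]\n",
    "    shares.append((secret - sum(shares)) % Q)\n",
    "    return shares\n",
    "\n",
    "def reconstruct(shares):\n",
    "    # Only the sum of all the shares reveals the secret\n",
    "    return sum(shares) % Q\n",
    "\n",
    "x_shares = share(25)  # e.g. one share for alice, one for bob\n",
    "y_shares = share(17)\n",
    "\n",
    "# Addition needs no communication: each party adds the two shares it holds\n",
    "z_shares = [(x + y) % Q for x, y in zip(x_shares, y_shares)]\n",
    "assert reconstruct(z_shares) == 42\n",
    "```\n",
    "\n",
    "Each share taken on its own is a uniformly random number, so a single worker learns nothing about the secret from the share it holds."
   ]
  },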
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To give an example of this process, let's assume alice and bob both hold a part of the MNIST dataset and let's train a model to perform digit classification!\n",
    "\n",
    "Author:\n",
    "- Théo Ryffel - Twitter: [@theoryffel](https://twitter.com/theoryffel) · GitHub: [@LaRiffle](https://github.com/LaRiffle)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 1. Encrypted Training demo on MNIST"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Imports and training configuration"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "import torch.optim as optim\n",
    "from torchvision import datasets, transforms\n",
    "\n",
    "import time"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "This class describes all the hyper-parameters for the training. Note that they are all public here."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class Arguments():\n",
    "    def __init__(self):\n",
    "        self.batch_size = 64\n",
    "        self.test_batch_size = 64\n",
    "        self.epochs = epochs\n",
    "        self.lr = 0.02\n",
    "        self.seed = 1\n",
    "        self.log_interval = 1 # Log info at each batch\n",
    "        self.precision_fractional = 3\n",
    "\n",
    "args = Arguments()\n",
    "\n",
    "_ = torch.manual_seed(args.seed)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Here are the PySyft imports. We connect to two remote workers that we call `alice` and `bob` (created below with the ids `worker1` and `worker2`) and request another worker called the `crypto_provider`, which provides all the crypto primitives we may need."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import syft as sy  # import the PySyft library\n",
    "hook = sy.TorchHook(torch)  # hook PyTorch to add extra functionalities like Federated and Encrypted Learning\n",
    "\n",
    "# simulation functions\n",
    "def connect_to_workers(n_workers):\n",
    "    return [\n",
    "        sy.VirtualWorker(hook, id=f\"worker{i+1}\")\n",
    "        for i in range(n_workers)\n",
    "    ]\n",
    "def connect_to_crypto_provider():\n",
    "    return sy.VirtualWorker(hook, id=\"crypto_provider\")\n",
    "\n",
    "workers = connect_to_workers(n_workers=2)\n",
    "crypto_provider = connect_to_crypto_provider()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Getting access and secret share data\n",
    "\n",
    "Here we're using a utility function which simulates the following behaviour: we assume the MNIST dataset is distributed in parts, each of which is held by one of our workers. The workers then split their data into batches and secret share their data with each other. The final object returned is an iterable over these secret shared batches, which we call the **private data loader**. Note that during the process the local worker (that is, us) never has access to the data.\n",
    "\n",
    "We obtain as usual a training and testing private dataset, and both the inputs and labels are secret shared."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_private_data_loaders(precision_fractional, workers, crypto_provider):\n",
    "    \n",
    "    def one_hot_of(index_tensor):\n",
    "        \"\"\"\n",
    "        Transform to one hot tensor\n",
    "        \n",
    "        Example:\n",
    "            [0, 3, 9]\n",
    "            =>\n",
    "            [[1., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n",
    "             [0., 0., 0., 1., 0., 0., 0., 0., 0., 0.],\n",
    "             [0., 0., 0., 0., 0., 0., 0., 0., 0., 1.]]\n",
    "            \n",
    "        \"\"\"\n",
    "        onehot_tensor = torch.zeros(*index_tensor.shape, 10) # 10 classes for MNIST\n",
    "        onehot_tensor = onehot_tensor.scatter(1, index_tensor.view(-1, 1), 1)\n",
    "        return onehot_tensor\n",
    "        \n",
    "    def secret_share(tensor):\n",
    "        \"\"\"\n",
    "        Transform to fixed precision and secret share a tensor\n",
    "        \"\"\"\n",
    "        return (\n",
    "            tensor\n",
    "            .fix_precision(precision_fractional=precision_fractional)\n",
    "            .share(*workers, crypto_provider=crypto_provider, requires_grad=True)\n",
    "        )\n",
    "    \n",
    "    transformation = transforms.Compose([\n",
    "        transforms.ToTensor(),\n",
    "        transforms.Normalize((0.1307,), (0.3081,))\n",
    "    ])\n",
    "    \n",
    "    train_loader = torch.utils.data.DataLoader(\n",
    "        datasets.MNIST('../data', train=True, download=True, transform=transformation),\n",
    "        batch_size=args.batch_size\n",
    "    )\n",
    "    \n",
    "    private_train_loader = [\n",
    "        (secret_share(data), secret_share(one_hot_of(target)))\n",
    "        for i, (data, target) in enumerate(train_loader)\n",
    "        if i < n_train_items / args.batch_size\n",
    "    ]\n",
    "    \n",
    "    test_loader = torch.utils.data.DataLoader(\n",
    "        datasets.MNIST('../data', train=False, download=True, transform=transformation),\n",
    "        batch_size=args.test_batch_size\n",
    "    )\n",
    "    \n",
    "    private_test_loader = [\n",
    "        (secret_share(data), secret_share(target.float()))\n",
    "        for i, (data, target) in enumerate(test_loader)\n",
    "        if i < n_test_items / args.test_batch_size\n",
    "    ]\n",
    "    \n",
    "    return private_train_loader, private_test_loader\n",
    "    \n",
    "    \n",
    "private_train_loader, private_test_loader = get_private_data_loaders(\n",
    "    precision_fractional=args.precision_fractional,\n",
    "    workers=workers,\n",
    "    crypto_provider=crypto_provider\n",
    ")"
   ]
  },
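  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The `secret_share` helper above first calls `.fix_precision`, because secret shares are integers. Here is a minimal sketch of the idea in plain Python, assuming a base-10 encoding with `precision_fractional = 3` decimal digits. The helper names `encode`, `decode` and `fixed_mul` are hypothetical, and this is not PySyft's implementation.\n",
    "\n",
    "```python\n",
    "PRECISION_FRACTIONAL = 3\n",
    "SCALE = 10 ** PRECISION_FRACTIONAL\n",
    "\n",
    "def encode(x):\n",
    "    # Keep 3 decimal digits and store the value as an integer\n",
    "    return int(round(x * SCALE))\n",
    "\n",
    "def decode(v):\n",
    "    return v / SCALE\n",
    "\n",
    "def fixed_mul(a, b):\n",
    "    # The product of two encoded values carries SCALE twice, so divide once\n",
    "    return (a * b) // SCALE\n",
    "\n",
    "lr, grad = encode(0.02), encode(1.5)\n",
    "assert (lr, grad) == (20, 1500)\n",
    "assert decode(fixed_mul(lr, grad)) == 0.03\n",
    "```\n",
    "\n",
    "With this encoding, any magnitude below 0.0005 is rounded to zero, so a larger `precision_fractional` trades bigger integers for finer resolution."
   ]
  },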
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Model specification\n",
    "\n",
    "Here is the model that we will use. It's a rather simple one, but [it has proved to perform reasonably well on MNIST](https://towardsdatascience.com/handwritten-digit-mnist-pytorch-977b5338e627)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class Net(nn.Module):\n",
    "    def __init__(self):\n",
    "        super(Net, self).__init__()\n",
    "        self.fc1 = nn.Linear(28 * 28, 128)\n",
    "        self.fc2 = nn.Linear(128, 64)\n",
    "        self.fc3 = nn.Linear(64, 10)\n",
    "\n",
    "    def forward(self, x):\n",
    "        x = x.view(-1, 28 * 28)\n",
    "        x = F.relu(self.fc1(x))\n",
    "        x = F.relu(self.fc2(x))\n",
    "        x = self.fc3(x)\n",
    "        return x"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Training and testing functions\n",
    "\n",
    "The training is done almost as usual. The real difference is that we can't use losses like negative log-likelihood (`F.nll_loss` in PyTorch), because it's quite complicated to reproduce these functions with SMPC. Instead, we use a simpler Mean Squared Error loss."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def train(args, model, private_train_loader, optimizer, epoch):\n",
    "    model.train()\n",
    "    for batch_idx, (data, target) in enumerate(private_train_loader): # <-- now it is a private dataset\n",
    "        start_time = time.time()\n",
    "        \n",
    "        optimizer.zero_grad()\n",
    "        \n",
    "        output = model(data)\n",
    "        \n",
    "        # loss = F.nll_loss(output, target)  <-- not possible here\n",
    "        batch_size = output.shape[0]\n",
    "        loss = ((output - target)**2).sum().refresh()/batch_size\n",
    "        \n",
    "        loss.backward()\n",
    "        \n",
    "        optimizer.step()\n",
    "\n",
    "        if batch_idx % args.log_interval == 0:\n",
    "            loss = loss.get().float_precision()\n",
    "            print('Train Epoch: {} [{}/{} ({:.0f}%)]\\tLoss: {:.6f}\\tTime: {:.3f}s'.format(\n",
    "                epoch, batch_idx * args.batch_size, len(private_train_loader) * args.batch_size,\n",
    "                100. * batch_idx / len(private_train_loader), loss.item(), time.time() - start_time))\n",
    "            "
   ]
  },
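  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "For reference, the loss used in `train` can be reproduced on ordinary (non-encrypted) tensors. The sketch below uses random stand-in values rather than real model outputs, and leaves out the SMPC-specific `.refresh()` call. It checks that the expression equals PyTorch's summed squared error divided by the batch size.\n",
    "\n",
    "```python\n",
    "import torch\n",
    "import torch.nn.functional as F\n",
    "\n",
    "torch.manual_seed(0)\n",
    "output = torch.randn(64, 10)  # stand-in for the model output on one batch\n",
    "target = F.one_hot(torch.randint(0, 10, (64,)), num_classes=10).float()\n",
    "\n",
    "batch_size = output.shape[0]\n",
    "loss = ((output - target) ** 2).sum() / batch_size\n",
    "\n",
    "# Same quantity written with the built-in loss\n",
    "reference = F.mse_loss(output, target, reduction='sum') / batch_size\n",
    "assert torch.allclose(loss, reference)\n",
    "```\n",
    "\n",
    "This loss only needs subtraction, multiplication and a sum, which are the operations that are cheap on secret shared values."
   ]
  },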
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The test function does not change!"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def test(args, model, private_test_loader):\n",
    "    model.eval()\n",
    "    test_loss = 0\n",
    "    correct = 0\n",
    "    with torch.no_grad():\n",
    "        for data, target in private_test_loader:\n",
    "            start_time = time.time()\n",
    "            \n",
    "            output = model(data)\n",
    "            pred = output.argmax(dim=1)\n",
    "            correct += pred.eq(target.view_as(pred)).sum()\n",
    "\n",
    "    correct = correct.get().float_precision()\n",
    "    print('\\nTest set: Accuracy: {}/{} ({:.0f}%)\\n'.format(\n",
    "        correct.item(), len(private_test_loader)* args.test_batch_size,\n",
    "        100. * correct.item() / (len(private_test_loader) * args.test_batch_size)))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Let's launch the training!\n",
    "\n",
    "A few notes about what's happening here. First, we secret share all the model parameters across our workers. Second, we convert the optimizer's hyperparameters to fixed precision. Note that we don't need to secret share them because they are public in our context, but as secret shared values live in finite fields, we still need to move the hyperparameters into finite fields using `.fix_precision`, so that operations like the weight update $W \\leftarrow W - \\alpha * \\Delta W$ are performed consistently."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "model = Net()\n",
    "model = model.fix_precision().share(*workers, crypto_provider=crypto_provider, requires_grad=True)\n",
    "\n",
    "optimizer = optim.SGD(model.parameters(), lr=args.lr)\n",
    "optimizer = optimizer.fix_precision() \n",
    "\n",
    "for epoch in range(1, args.epochs + 1):\n",
    "    train(args, model, private_train_loader, optimizer, epoch)\n",
    "    test(args, model, private_test_loader)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "There you are! You just reached 75% accuracy using a tiny fraction of the MNIST dataset, with 100% encrypted training!"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 2. Discussion\n",
    "\n",
    "Let's have a closer look at the power of encrypted training by analyzing what we just did."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 2.1 Computation time\n",
    "\n",
    "The first thing is obviously the running time! As you have surely noticed, it is far slower than plain text training. In particular, an iteration over 1 batch of 64 items takes 3.2s, compared with only 13ms in pure PyTorch. While this might seem like a blocker, recall that here everything happened remotely and in the encrypted world: no single data item has been disclosed. More specifically, the time to process one item is 50ms (3.2s divided by 64 items), which is not that bad. The real question is to analyze when encrypted training is needed and when encrypted prediction alone is sufficient. For example, 50ms to perform a prediction is completely acceptable in a production-ready scenario!\n",
    "\n",
    "One main bottleneck is the use of costly activation functions: ReLU activations are very costly with SMPC because they use private comparison and the SecureNN protocol. As an illustration, if we replace ReLU with a quadratic activation, as is done in several papers on encrypted computation like CryptoNets, we drop from 3.2s to 1.2s.\n",
    "\n",
    "As a general rule, the key idea to take away is to encrypt only what's necessary, and this tutorial shows you how simple it can be."
   ]
  },
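  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Here is a minimal sketch of what replacing ReLU with a quadratic activation could look like in plain PyTorch. The class name `QuadraticNet` and the choice of a plain square `x * x` are assumptions for illustration; the exact variant behind the 1.2s figure is not specified here.\n",
    "\n",
    "```python\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "\n",
    "class QuadraticNet(nn.Module):\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "        self.fc1 = nn.Linear(28 * 28, 128)\n",
    "        self.fc2 = nn.Linear(128, 64)\n",
    "        self.fc3 = nn.Linear(64, 10)\n",
    "\n",
    "    def forward(self, x):\n",
    "        x = x.view(-1, 28 * 28)\n",
    "        x = self.fc1(x)\n",
    "        x = x * x  # square activation: multiplications only, no comparison\n",
    "        x = self.fc2(x)\n",
    "        x = x * x\n",
    "        return self.fc3(x)\n",
    "\n",
    "out = QuadraticNet()(torch.randn(4, 1, 28, 28))\n",
    "assert out.shape == (4, 10)\n",
    "```\n",
    "\n",
    "A square grows much faster than ReLU, so such a model may need a smaller learning rate or careful input scaling to stay within the fixed precision range."
   ]
  },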
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 2.2 Backpropagation with SMPC\n",
    "\n",
    "You might wonder how we perform backpropagation and gradient updates even though we're working with integers in finite fields. To do so, we have developed a new syft tensor called AutogradTensor. This tutorial used it intensively, although you might not have noticed it! Let's check this by printing one of the model's parameters:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "model.fc3.bias"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "And a data item:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "first_batch, input_data = 0, 0\n",
    "private_train_loader[first_batch][input_data]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As you can observe, the AutogradTensor is there! It lives between the torch wrapper and the FixedPrecisionTensor, which indicates that the values are now in finite fields. The goal of this AutogradTensor is to store the computation graph when operations are made on encrypted values. This is useful because, when calling backward for the backpropagation, this AutogradTensor overrides all the backward functions that are not compatible with encrypted computation and indicates how to compute these gradients. For example, multiplication is done using the Beaver triples trick, and we don't want to differentiate the trick itself, all the more so because differentiating a multiplication is very easy: $\\frac{\\partial (a \\cdot b)}{\\partial b} = a$. Here is, for example, how we describe the computation of these gradients:\n",
    "\n",
    "```python\n",
    "class MulBackward(GradFunc):\n",
    "    def __init__(self, self_, other):\n",
    "        super().__init__(self, self_, other)\n",
    "        self.self_ = self_\n",
    "        self.other = other\n",
    "\n",
    "    def gradient(self, grad):\n",
    "        grad_self_ = grad * self.other\n",
    "        grad_other = grad * self.self_ if type(self.self_) == type(self.other) else None\n",
    "        return (grad_self_, grad_other)\n",
    "```\n",
    "\n",
    "You can have a look at `tensors/interpreters/gradients.py` if you're curious to see how we implemented more gradients.\n",
    "\n",
    "In terms of the computation graph, this means that a copy of the graph remains local and that the server which coordinates the forward pass also provides instructions on how to do the backward pass. This is a completely valid hypothesis in our setting."
   ]
  },
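  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To make the Beaver triples trick mentioned above more concrete, here is a minimal two-party sketch in plain Python. It is not PySyft's implementation: the ring size `Q` and the helpers `share`, `reconstruct` and `beaver_mul` are assumptions for illustration, and the role of the crypto provider is simulated locally.\n",
    "\n",
    "```python\n",
    "import random\n",
    "\n",
    "Q = 2**62  # ring size (assumed value for illustration)\n",
    "\n",
    "def share(secret):\n",
    "    s0 = random.randrange(Q)\n",
    "    return [s0, (secret - s0) % Q]\n",
    "\n",
    "def reconstruct(shares):\n",
    "    return sum(shares) % Q\n",
    "\n",
    "def beaver_mul(x_sh, y_sh):\n",
    "    # The crypto provider samples a random triple with c = a * b and shares it\n",
    "    a, b = random.randrange(Q), random.randrange(Q)\n",
    "    a_sh, b_sh, c_sh = share(a), share(b), share(a * b % Q)\n",
    "\n",
    "    # The parties open the masked values d = x - a and e = y - b\n",
    "    d = reconstruct([(x - ai) % Q for x, ai in zip(x_sh, a_sh)])\n",
    "    e = reconstruct([(y - bi) % Q for y, bi in zip(y_sh, b_sh)])\n",
    "\n",
    "    # x * y = c + d * b + e * a + d * e, computed share by share\n",
    "    z_sh = [(ci + d * bi + e * ai) % Q for ci, ai, bi in zip(c_sh, a_sh, b_sh)]\n",
    "    z_sh[0] = (z_sh[0] + d * e) % Q  # the public term d * e is added by one party only\n",
    "    return z_sh\n",
    "\n",
    "z_shares = beaver_mul(share(6), share(7))\n",
    "assert reconstruct(z_shares) == 42\n",
    "```\n",
    "\n",
    "The opened values `d` and `e` are masked by the random `a` and `b`, so they reveal nothing about the inputs. This masking protocol is the part that the backward pass does not need to differentiate."
   ]
  },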
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 2.3 Security guarantees\n",
    "\n",
    "Finally, let's give a few hints about the security we're achieving here. The adversaries that we consider are **honest but curious**: an adversary that follows the protocol can't learn anything about the data by running it, but a malicious adversary could still deviate from the protocol and, for example, try to corrupt the shares to sabotage the computation. Security against malicious adversaries in such SMPC computations, including private comparison, is still an open problem.\n",
    "\n",
    "In addition, even if Secure Multi-Party Computation ensures that the training data wasn't accessed, many threats from the plain text world are still present here. For example, as you can make requests to the model (in the context of MLaaS), you can get predictions which might disclose information about the training dataset. In particular, you don't have any protection against membership attacks, a common attack on machine learning services where the adversary wants to determine whether a specific item was used in the dataset. Besides this, other attacks such as unintended memorization (models learning specific features of a data item), model inversion or extraction are still possible.\n",
    "\n",
    "One general solution which is effective for many of the threats mentioned above is to add Differential Privacy. It can be nicely combined with Secure Multi-Party Computation and can provide very interesting security guarantees. We're currently working on several implementations and hope to propose an example that combines both shortly!"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Conclusion\n",
    "\n",
    "As you have seen, training a model using SMPC is not complicated from a code point of view, even though we use rather complex objects under the hood. With this in mind, you should now analyse your use-cases to see when encrypted computation is needed, either for training or for evaluation. Although encrypted computation is much slower in general, it can be used selectively so as to reduce the overall computation overhead.\n",
    "\n",
    "If you enjoyed this and would like to join the movement toward privacy preserving, decentralized ownership of AI and the AI supply chain (data), you can do so in the following ways!\n",
    "\n",
    "### Star PySyft on GitHub\n",
    "\n",
    "The easiest way to help our community is just by starring the repositories! This helps raise awareness of the cool tools we're building.\n",
    "\n",
    "- [Star PySyft](https://github.com/OpenMined/PySyft)\n",
    "\n",
    "### Pick our tutorials on GitHub!\n",
    "\n",
    "We made really nice tutorials to get a better understanding of what Federated and Privacy-Preserving Learning should look like and how we are building the bricks for this to happen.\n",
    "\n",
    "- [Check out the PySyft tutorials](https://github.com/OpenMined/PySyft/tree/master/examples/tutorials)\n",
    "\n",
    "\n",
    "### Join our Slack!\n",
    "\n",
    "The best way to keep up to date on the latest advancements is to join our community!\n",
    "\n",
    "- [Join slack.openmined.org](http://slack.openmined.org)\n",
    "\n",
    "### Join a Code Project!\n",
    "\n",
    "The best way to contribute to our community is to become a code contributor! If you want to start \"one off\" mini-projects, you can go to the PySyft GitHub Issues page and search for issues marked `Good First Issue`.\n",
    "\n",
    "- [Good First Issue Tickets](https://github.com/OpenMined/PySyft/issues?q=is%3Aopen+is%3Aissue+label%3A%22good+first+issue%22)\n",
    "\n",
    "### Donate\n",
    "\n",
    "If you don't have time to contribute to our codebase, but would still like to lend support, you can also become a Backer on our Open Collective. All donations go toward our web hosting and other community expenses such as hackathons and meetups!\n",
    "\n",
    "- [Donate through OpenMined's Open Collective Page](https://opencollective.com/openmined)"
   ]
  }
 ],
 "metadata": {
  "celltoolbar": "Tags",
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}